#!/bin/python
#coding:UTF-8
'''
Date:20160426
@author: zhaozhiyong
'''

import math
import sys
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import MeanShift, estimate_bandwidth

class meanShift():
    """Mean shift clustering with a Gaussian kernel.

    Points are handled row-wise: an ``(m, n)`` input holds ``m`` points with
    ``n`` coordinates each.  Public methods keep returning ``numpy.matrix``
    objects for compatibility with existing callers; internally plain
    ``ndarray`` arithmetic is used.

    Attributes:
        MIN_DISTANCE: a point is considered converged once a single shift
            moves it by less than this distance.
        miniDisappearDist: threshold intended for merging shifted points that
            end up very close together.  No method of this class reads it;
            it is kept so code that sets or reads it keeps working.
        verbose: when true, progress diagnostics are printed.
        show_plots: when true, ``train_mean_shift`` plots every sweep.
    """

    # Class-level defaults so instances created without __init__ still work.
    verbose = True
    show_plots = True

    def __init__(self, verbose=True, show_plots=True):
        """Create a clusterer.

        Args:
            verbose: print diagnostics while shifting and grouping.
            show_plots: call ``dispaly`` (blocking ``plt.show()``) before
                every sweep of ``train_mean_shift``.
        """
        self.MIN_DISTANCE = 0.000001  # convergence tolerance for one shift
        self.miniDisappearDist = 100
        self.verbose = verbose
        self.show_plots = show_plots

    def _log(self, *args):
        """Print diagnostics only when ``self.verbose`` is set."""
        if self.verbose:
            print(*args)

    @staticmethod
    def _format_coordinate(value):
        """Format one coordinate with two decimals for use in a group key."""
        text = "%5.2f" % value
        # A tiny negative value formats as "-0.00"; it must share a key with
        # " 0.00", otherwise one mode is split into two groups.
        return " 0.00" if text == "-0.00" else text

    def gaussian_kernel(self, distance, bandwidth):
        """Evaluate the Gaussian kernel for a set of distances.

        Args:
            distance: ``m`` distances, either 1-D or as an ``(m, k)``
                array/matrix whose rows are reduced by their sum of squares
                (for ``k == 1`` that is simply the squared distance).
            bandwidth: kernel bandwidth; must be non-zero.

        Returns:
            An ``(m, 1)`` ``numpy.matrix`` holding
            ``exp(-0.5 * d^2 / bandwidth^2) / (bandwidth * sqrt(2 * pi))``.
        """
        left = 1 / (bandwidth * math.sqrt(2 * math.pi))
        dist = np.asarray(distance, dtype=float)
        if dist.ndim == 1:
            dist = dist[:, np.newaxis]
        squared = np.sum(dist * dist, axis=1, keepdims=True)
        right = np.exp(-0.5 * squared / (bandwidth * bandwidth))
        # np.asmatrix instead of the np.mat alias, which NumPy 2.0 removed.
        return np.asmatrix(left * right)

    def shift_point(self, point, points, kernel_bandwidth):
        """Move ``point`` to the kernel-weighted mean of ``points``.

        Args:
            point: the point to shift, 1-D or shaped ``(1, n)``.
            points: the data set, one point per row (``(m, n)``).
            kernel_bandwidth: bandwidth passed to ``gaussian_kernel``.

        Returns:
            The shifted point as a ``(1, n)`` ``numpy.matrix``.  When every
            kernel weight underflows to zero the point is returned unchanged
            instead of dividing by zero.
        """
        data = np.atleast_2d(np.asarray(points, dtype=float))
        centre = np.atleast_2d(np.asarray(point, dtype=float))

        # Euclidean distance from the point to every sample.
        offsets = data - centre
        distances = np.sqrt(np.sum(offsets * offsets, axis=1))

        # Gaussian weight of every sample.
        weights = np.asarray(
            self.gaussian_kernel(distances, kernel_bandwidth)).ravel()

        # Normalising denominator.
        total = float(weights.sum())
        self._log("all:", total, " m:", data.shape[0])
        if total == 0.0:
            # No sample carries measurable weight, so there is nothing to
            # move towards.
            return np.asmatrix(centre.copy())

        # Weighted mean = the shifted position.
        return np.asmatrix(weights.dot(data) / total)

    def euclidean_dist(self, pointA, pointB):
        """Return the Euclidean distance between two points as a float.

        Both arguments may be lists, arrays or ``(1, n)`` matrices of the
        same shape.
        """
        diff = np.asarray(pointA, dtype=float) - np.asarray(pointB, dtype=float)
        return math.sqrt(float(np.sum(diff * diff)))

    def distance_to_group(self, point, group):
        """Return the smallest distance from ``point`` to any member of ``group``.

        Returns ``float('inf')`` for an empty group, so the result is never
        capped by an arbitrary sentinel.
        """
        return min((self.euclidean_dist(point, pt) for pt in group),
                   default=float("inf"))

    def group_points(self, mean_shift_points):
        """Assign a cluster index to every shifted point.

        Points whose coordinates agree after rounding to two decimals share
        an index; indices are handed out in order of first appearance.

        Args:
            mean_shift_points: ``(m, n)`` array/matrix of shifted points.

        Returns:
            A list of ``m`` integer group indices.
        """
        rows = np.atleast_2d(np.asarray(mean_shift_points, dtype=float))
        index_dict = {}
        group_assignment = []
        for row in rows:
            key = "_".join(self._format_coordinate(value) for value in row)
            self._log(key)
            # len(index_dict) is the next unused index for a new key.
            group_assignment.append(index_dict.setdefault(key, len(index_dict)))
        return group_assignment

    def dispaly(self, points, pointB=''):
        """Scatter-plot ``points`` (blue) and optionally ``pointB`` (red).

        Column 0 is used as x and column 1 as y.  ``pointB`` is skipped when
        it is ``None`` or the empty-string default.  Blocks in ``plt.show()``.
        """
        primary = np.asarray(points)
        plt.plot(primary[:, 0], primary[:, 1], 'b.', label="original data")
        # Comparing an array with '' is elementwise, so test the sentinel
        # explicitly.
        no_overlay = pointB is None or (isinstance(pointB, str) and pointB == '')
        if not no_overlay:
            overlay = np.asarray(pointB)
            plt.plot(overlay[:, 0], overlay[:, 1], 'r.', label="data")
        plt.title('Mean Shift')
        plt.legend(loc="upper right")
        plt.show()

    # Correctly spelled alias; ``dispaly`` stays for existing callers.
    display = dispaly

    def train_mean_shift(self, points, kenel_bandwidth=2):
        """Run mean shift until no point moves more than ``MIN_DISTANCE``.

        Every point is repeatedly replaced by the kernel-weighted mean of the
        (fixed) input data around it.  A point that moves less than
        ``MIN_DISTANCE`` in one step is frozen.  The input is never modified.

        Args:
            points: ``(m, n)`` data, one point per row.
            kenel_bandwidth: Gaussian kernel bandwidth.

        Returns:
            A tuple ``(points, shifted, group)``: the input as a matrix, the
            converged positions as an ``(m, n)`` float matrix, and the list
            of group indices from ``group_points``.

        Raises:
            ValueError: if ``points`` has more than two dimensions.
        """
        # Private float copy: a view would overwrite the caller's data, and
        # an integer dtype would truncate every shift.
        data = np.atleast_2d(np.array(points, dtype=float))
        if data.ndim != 2:
            raise ValueError("points must be 2-dimensional, one point per row")
        shifted = data.copy()
        self._log(shifted.shape)

        m = data.shape[0]
        need_shift = [True] * m
        sweep = 0
        largest_move = float("inf")  # guarantees at least one sweep

        while largest_move > self.MIN_DISTANCE:
            largest_move = 0.0
            sweep += 1
            self._log("iter : " + str(sweep))
            if self.show_plots:
                # Blue = input data, red = current shifted positions.
                self.dispaly(data, shifted)

            for i in range(m):
                # Skip points that have already converged.
                if not need_shift[i]:
                    continue
                start = shifted[i].copy()
                moved = np.asarray(
                    self.shift_point(start, data, kenel_bandwidth)).ravel()
                dist = self.euclidean_dist(moved, start)

                if dist > largest_move:  # track the largest move of the sweep
                    largest_move = dist
                if dist < self.MIN_DISTANCE:  # converged, stop moving it
                    need_shift[i] = False

                shifted[i] = moved

        # Derive the final group index of every point.
        group = self.group_points(shifted)

        return np.asmatrix(points), np.asmatrix(shifted), group

#https://blog.csdn.net/google19890102/article/details/51030884

# Sample data: row 0 holds the x coordinates and row 1 the y coordinates,
# one column per point.
data = [[ 406,  405,  311,  312,  319,  320,  233,  232,  321,  405,  322,  404,
   229,  228,  316,  316],
 [1066, 1068, 1231, 1230, 1028, 1026, 1190, 1191, 1028, 1066, 1027, 1067,
  1187, 1189, 1234, 1232]]

data = np.array(data)

# train_mean_shift expects one point per row, hence the transpose.
points, shift_points, cluster = meanShift().train_mean_shift(data.T,200)
print(points,"___________")
print(shift_points,"___________")
print(cluster,"__________")

print('shape of point:%s shape of center:%s'%(points.shape,shift_points.shape))

# Plot x against y.  Passing the (m, 2) matrix directly would make matplotlib
# draw each column against the row index instead of a scatter of the points.
points_xy = np.asarray(points)
centers_xy = np.asarray(shift_points)
plt.plot(points_xy[:, 0], points_xy[:, 1], 'b.', label="original data")
plt.plot(centers_xy[:, 0], centers_xy[:, 1], 'r.', label="center data")
plt.title('Mean Shift')
plt.legend(loc="upper right")
# The figure is prepared but intentionally not shown.
#plt.show()